{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLX\n",
    "\n",
    "This notebook shows how to get started using `MLX` LLM's as chat models.\n",
    "\n",
    "In particular, we will:\n",
    "1. Utilize the [MLXPipeline](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/mlx_pipeline.py), \n",
    "2. Utilize the `ChatMLX` class to enable any of these LLMs to interface with LangChain's [Chat Messages](https://python.langchain.com/docs/modules/model_io/chat/#messages) abstraction.\n",
    "3. Demonstrate how to use an open-source LLM to power an `ChatAgent` pipeline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet  mlx-lm transformers huggingface_hub"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Instantiate an LLM\n",
    "\n",
    "There are three LLM options to choose from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.llms.mlx_pipeline import MLXPipeline\n",
    "\n",
    "llm = MLXPipeline.from_model_id(\n",
    "    \"mlx-community/quantized-gemma-2b-it\",\n",
    "    pipeline_kwargs={\"max_tokens\": 10, \"temp\": 0.1},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Instantiate the `ChatMLX` to apply chat templates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the chat model and some messages to pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.chat_models.mlx import ChatMLX\n",
    "from langchain_core.messages import HumanMessage\n",
    "\n",
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"What happens when an unstoppable force meets an immovable object?\"\n",
    "    ),\n",
    "]\n",
    "\n",
    "chat_model = ChatMLX(llm=llm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect how the chat messages are formatted for the LLM call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_model._to_chat_prompt(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = chat_model.invoke(messages)\n",
    "print(res.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Take it for a spin as an agent!\n",
    "\n",
    "Here we'll test out `gemma-2b-it` as a zero-shot `ReAct` Agent. The example below is taken from [here](https://python.langchain.com/docs/modules/agents/agent_types/react#using-chat-models).\n",
    "\n",
    "> Note: To run this section, you'll need to have a [SerpAPI Token](https://serpapi.com/) saved as an environment variable: `SERPAPI_API_KEY`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain import hub\n",
    "from langchain.agents import AgentExecutor, load_tools\n",
    "from langchain.agents.format_scratchpad import format_log_to_str\n",
    "from langchain.agents.output_parsers import (\n",
    "    ReActJsonSingleInputOutputParser,\n",
    ")\n",
    "from langchain.tools.render import render_text_description\n",
    "from langchain_community.utilities import SerpAPIWrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Configure the agent with a `react-json` style prompt and access to a search engine and calculator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup tools\n",
    "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
    "\n",
    "# setup ReAct style prompt\n",
    "prompt = hub.pull(\"hwchase17/react-json\")\n",
    "prompt = prompt.partial(\n",
    "    tools=render_text_description(tools),\n",
    "    tool_names=\", \".join([t.name for t in tools]),\n",
    ")\n",
    "\n",
    "# define the agent\n",
    "chat_model_with_stop = chat_model.bind(stop=[\"\\nObservation\"])\n",
    "agent = (\n",
    "    {\n",
    "        \"input\": lambda x: x[\"input\"],\n",
    "        \"agent_scratchpad\": lambda x: format_log_to_str(x[\"intermediate_steps\"]),\n",
    "    }\n",
    "    | prompt\n",
    "    | chat_model_with_stop\n",
    "    | ReActJsonSingleInputOutputParser()\n",
    ")\n",
    "\n",
    "# instantiate AgentExecutor\n",
    "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_executor.invoke(\n",
    "    {\n",
    "        \"input\": \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
    "    }\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
